use std::{
    collections::VecDeque,
    io::{Error as IOError, Read, Write, Seek, SeekFrom, ErrorKind}, 
};

use crate::common::{
    symcodec::SymbolCodec,
    bitcodes::VarLenBitCode32, 
    symbol::{SymbolWithin8Bits, SymWithFreqU32},
    bitio::{BitsRead, BitsWrite},
    trees::TrivialNodeFullBinTree
};

/// Shannon–Fano codec for symbols of type `S`.
///
/// Keeps the code tree, which `encode` serializes into the output so that `decode` can rebuild
/// it, alongside a flat lookup table used to emit codes while encoding.
pub struct FanoCodec<S> {
    /// Code per symbol, indexed by the symbol's `u8` value. A `len` of 0 means no code was
    /// assigned (the symbol is not a leaf of `tree`) or `tree` consists of a single leaf.
    encode_table_arr: Box<[VarLenBitCode32]>,
    /// Full binary code tree; the path to a leaf (left = 0, right = 1) is that symbol's code.
    tree: TrivialNodeFullBinTree<SymWithFreqU32<S>>,
}

impl<S: SymbolWithin8Bits> FanoCodec<S> {
    /// Builds a Shannon–Fano code from per-symbol frequencies.
    ///
    /// Symbols are sorted by descending frequency (ties keep their input order). The list is
    /// then recursively split into a leading run and a trailing run whose frequency sums are
    /// as close as possible. The leading (higher-frequency) run becomes the `1` branch and the
    /// trailing run the `0` branch.
    ///
    /// * `symbol_freq_data` - one entry per symbol that may occur in the data to encode.
    /// * `skip_zero_freq` - when `true`, zero-frequency symbols are left out of the tree and
    ///   keep an empty (`len == 0`) entry in the encode table.
    ///
    /// # Panics
    ///
    /// Panics if no symbol is left to build the tree from (empty input, or every frequency is
    /// zero while `skip_zero_freq` is set). Also panics if a code would need more than 32 bits.
    pub fn build(mut symbol_freq_data: Vec<SymWithFreqU32<S>>, skip_zero_freq: bool) -> Self {
        /// A run of symbols together with the exact sum of their frequencies.
        struct VecWithSum<S> {
            vec: VecDeque<SymWithFreqU32<S>>,
            sum: u64,
        }

        /// Recursively splits a run of symbols (sorted by descending frequency) into a tree.
        fn build_tree_rec<S>(mut l: VecWithSum<S>) -> TrivialNodeFullBinTree<SymWithFreqU32<S>> {
            if l.vec.len() < 2 {
                return TrivialNodeFullBinTree::Leave(
                    l.vec
                        .pop_front()
                        .expect("cannot build a Fano code tree from zero symbols"),
                );
            }

            // Number of leading symbols that go into the `1` branch.
            let split = if l.sum == 0 {
                // Every remaining frequency is zero: split by count so the tree stays
                // balanced instead of peeling symbols off one at a time.
                l.vec.len() / 2
            } else {
                // Exact integer form of the Fano split, with half = sum / 2:
                //   acc + f >= half                 <=>  2 * (acc + f) >= sum
                //   (acc + f) - half < half - acc   <=>  (acc + f) + acc < sum
                let mut acc = 0u64;
                let mut k = 0usize;
                for item in l.vec.iter() {
                    let next = acc + u64::from(item.freq);
                    if 2 * next >= l.sum {
                        // Include the crossing symbol only if that lands closer to the half.
                        if next + acc < l.sum {
                            k += 1;
                        }
                        break;
                    }
                    acc = next;
                    k += 1;
                }
                // The crossing test always fires by the last symbol without taking it, so
                // k <= len - 1. Forcing k >= 1 keeps the leading run non-empty as well, e.g.
                // when the first symbol carries the whole weight.
                k.max(1)
            };

            // After `split_off`, `l.vec` holds the leading run and `rest` the trailing run.
            let rest = l.vec.split_off(split);
            let head_sum: u64 = l.vec.iter().map(|x| u64::from(x.freq)).sum();
            let rest_sum = l.sum - head_sum;

            TrivialNodeFullBinTree::Node {
                lc: Box::new(build_tree_rec(VecWithSum {
                    vec: rest,
                    sum: rest_sum,
                })),
                rc: Box::new(build_tree_rec(VecWithSum {
                    vec: l.vec,
                    sum: head_sum,
                })),
            }
        }

        /// Records, for every leaf, its root-to-leaf path (left = 0, right = 1) and its length.
        fn walk_tree_rec<S: SymbolWithin8Bits>(
            tree: &TrivialNodeFullBinTree<SymWithFreqU32<S>>,
            table_arr: &mut [VarLenBitCode32],
            code: u32,
            depth: u8,
        ) {
            match tree {
                TrivialNodeFullBinTree::Node { lc, rc } => {
                    // Codes live in a u32: children of a node at depth 31 are the deepest
                    // leaves that still fit.
                    assert!(depth < 32, "Fano code would exceed 32 bits");
                    walk_tree_rec(lc, table_arr, code << 1, depth + 1);
                    walk_tree_rec(rc, table_arr, (code << 1) | 1, depth + 1);
                }
                TrivialNodeFullBinTree::Leave(SymWithFreqU32 { sym, .. }) => {
                    table_arr[sym.val_in_u8() as usize] = VarLenBitCode32 { len: depth, code };
                }
            }
        }

        // Stable sort by descending frequency.
        symbol_freq_data.sort_by(|x, y| y.freq.cmp(&x.freq));

        let mut total_sum = 0u64;
        let mut input_vec_deque = VecDeque::with_capacity(symbol_freq_data.len());
        for x in symbol_freq_data {
            // Sorted descending: once a zero shows up, everything after it is zero too.
            if skip_zero_freq && x.freq == 0 {
                break;
            }
            total_sum += u64::from(x.freq);
            input_vec_deque.push_back(x);
        }

        let tree = build_tree_rec(VecWithSum {
            vec: input_vec_deque,
            sum: total_sum,
        });

        let mut encode_table_arr =
            vec![VarLenBitCode32 { code: 0, len: 0 }; 1 << (S::bits_needed())];
        walk_tree_rec(&tree, &mut encode_table_arr, 0u32, 0u8);

        Self {
            tree,
            encode_table_arr: encode_table_arr.into_boxed_slice(),
        }
    }
}

impl<S: SymbolWithin8Bits> SymbolCodec for FanoCodec<S> {
    /// Encodes up to `symbol_read_limit` symbols of `S::bits_needed()` bits each, read from
    /// `i_byte_reader`, and returns how many symbols were encoded.
    ///
    /// Output is written starting at the writer's current position. The header is patched
    /// afterwards, and the writer is left positioned just after the encoded data.
    ///
    /// NOTE(review): when the tree is a single leaf, every symbol gets a 0-bit code, so the
    /// payload is empty and the decoder cannot reproduce the symbols.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors. Returns `ErrorKind::InvalidInput` if the input contains a symbol
    /// with no code in this codec's table (e.g. a zero-frequency symbol skipped at build time).
    fn encode<R, W>(
        &self,
        i_byte_reader: &mut R,
        o_byte_writer: &mut W,
        symbol_read_limit: u32,
    ) -> Result<u32, IOError>
    where
        R: Read,
        W: Write + Seek,
    {
        // encoded stream layout (byte offsets relative to the writer position on entry):
        // |symbol len in bits|tree data len in bits|payload size in bits|tree data|payload data|
        // 0                  1                     3                   11         ?            EOF

        /// Tree data: prefix traversal, output bit 0 for a node,
        /// output bit 1 and the associated symbol for a leaf. Returns the number of bits written.
        fn write_tree_struct_rec<S: SymbolWithin8Bits, W: Write>(
            tree: &TrivialNodeFullBinTree<SymWithFreqU32<S>>,
            bitwriter: &mut BitsWrite<W>,
        ) -> Result<u16, IOError> {
            match tree {
                TrivialNodeFullBinTree::Node { lc, rc, .. } => {
                    bitwriter.write_bits_32_8_unchecked(0, 1)?;
                    let ln = write_tree_struct_rec(lc, bitwriter)?;
                    let rn = write_tree_struct_rec(rc, bitwriter)?;
                    Ok(ln + rn + 1)
                }
                TrivialNodeFullBinTree::Leave(SymWithFreqU32 { sym, .. }) => {
                    bitwriter.write_bits_32_8_unchecked(1, 1)?;
                    bitwriter
                        .write_bits_32_8_unchecked(sym.val_in_u8() as u32, S::bits_needed())?;
                    Ok(1 + S::bits_needed() as u16)
                }
            }
        }

        let nbits = S::bits_needed();
        // A multi-leaf tree gives every coded symbol len >= 1, so len == 0 means "no code".
        let has_codes = matches!(self.tree, TrivialNodeFullBinTree::Node { .. });

        let header_start = o_byte_writer.stream_position()?;
        o_byte_writer.write_all(&[nbits])?;
        // Placeholder for the two length fields; patched once the lengths are known.
        o_byte_writer.write_all(&[0u8; 10])?;

        let mut bitreader = BitsRead::create_with(i_byte_reader);
        let mut bitwriter = BitsWrite::create_with(o_byte_writer);

        let tree_struct_bits = write_tree_struct_rec(&self.tree, &mut bitwriter)?;
        // Byte-align so the payload starts on a fresh byte.
        bitwriter.flush_buffered_bits()?;

        let mut passed = 0u32;
        let mut payload_bits = 0u64;
        let mut b = 0u8;
        while passed < symbol_read_limit && bitreader.read_bits_8_8_unchecked(&mut b, nbits)? {
            let VarLenBitCode32 { len, code } = self.encode_table_arr[b as usize];

            if len == 0 && has_codes {
                return Err(IOError::new(
                    ErrorKind::InvalidInput,
                    format!("Symbol {} is not present in the Fano code table", b),
                ));
            }

            bitwriter.write_bits_32_8_unchecked(code, len)?;
            payload_bits += len as u64;
            passed += 1;
        }
        bitwriter.flush_buffered_bits()?;

        // Patch the header relative to where it was written, then restore the cursor to the
        // end of the encoded data so callers can keep appending.
        let end_pos = o_byte_writer.stream_position()?;
        o_byte_writer.seek(SeekFrom::Start(header_start + 1))?;
        o_byte_writer.write_all(&tree_struct_bits.to_le_bytes())?;
        o_byte_writer.write_all(&payload_bits.to_le_bytes())?;
        o_byte_writer.seek(SeekFrom::Start(end_pos))?;

        Ok(passed)
    }

    /// Decodes a stream produced by `encode`, writing the symbols to `o_byte_writer`, and
    /// returns the number of symbols decoded.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors. Returns `ErrorKind::InvalidData` (or `UnexpectedEof` for a
    /// truncated tree) if the header, tree or payload are inconsistent.
    fn decode<R, W>(i_byte_reader: &mut R, o_byte_writer: &mut W) -> Result<u32, IOError>
    where
        R: Read + Seek,
        W: Write,
    {
        /// Rebuilds the code tree from its prefix-order serialization. A 0 bit is an inner
        /// node followed by its two children; a 1 bit is a leaf followed by `symbol_bits`
        /// bits of symbol. Returns the number of bits consumed together with the tree.
        fn build_decode_tree_rec<R: Read>(
            bitsreader: &mut BitsRead<R>,
            symbol_bits: u8,
            depth: u32,
        ) -> Result<(u32, TrivialNodeFullBinTree<u8>), IOError> {
            // A tree with at most 2^symbol_bits distinct leaves is never deeper than this.
            // Bounding the recursion keeps corrupted input from overflowing the stack.
            if depth > (1u32 << symbol_bits) {
                return Err(IOError::new(
                    ErrorKind::InvalidData,
                    "Incorrect or corrupted data! (Tree data too deep)",
                ));
            }
            let mut b = 0u8;
            if !bitsreader.read_bits_8_8_unchecked(&mut b, 1)? {
                return Err(IOError::new(
                    ErrorKind::UnexpectedEof,
                    "Incorrect or corrupted data! (Truncated tree data)",
                ));
            }
            if b == 0 {
                let (lbits, lc) = build_decode_tree_rec(bitsreader, symbol_bits, depth + 1)?;
                let (rbits, rc) = build_decode_tree_rec(bitsreader, symbol_bits, depth + 1)?;
                Ok((
                    1u32.saturating_add(lbits).saturating_add(rbits),
                    TrivialNodeFullBinTree::Node {
                        lc: Box::new(lc),
                        rc: Box::new(rc),
                    },
                ))
            } else {
                if !bitsreader.read_bits_8_8_unchecked(&mut b, symbol_bits)? {
                    return Err(IOError::new(
                        ErrorKind::UnexpectedEof,
                        "Incorrect or corrupted data! (Truncated tree data)",
                    ));
                }
                Ok((1 + u32::from(symbol_bits), TrivialNodeFullBinTree::Leave(b)))
            }
        }

        let mut onebyte_buf = [0u8];
        let mut twobyte_buf = [0u8; 2];
        let mut eightbyte_buf = [0u8; 8];
        i_byte_reader.read_exact(&mut onebyte_buf)?;
        i_byte_reader.read_exact(&mut twobyte_buf)?;
        i_byte_reader.read_exact(&mut eightbyte_buf)?;

        let symbol_bits = onebyte_buf[0];
        let treestruct_bits = u16::from_le_bytes(twobyte_buf);
        let payload_bits = u64::from_le_bytes(eightbyte_buf);

        // The symbol width comes from untrusted data. The `_8_8`/`_32_8` "unchecked" bit
        // helpers presumably expect at most 8 bits here, and 0 bits cannot identify a symbol.
        if symbol_bits == 0 || symbol_bits > 8 {
            return Err(IOError::new(
                ErrorKind::InvalidData,
                format!(
                    "Incorrect or corrupted data! (Invalid symbol length: {} bits)",
                    symbol_bits
                ),
            ));
        }

        let mut bitsreader = BitsRead::create_with(i_byte_reader);
        let (tree_bits_read, decode_tree) = build_decode_tree_rec(&mut bitsreader, symbol_bits, 0)?;
        let root = &decode_tree;

        // integrity check #1
        if tree_bits_read != u32::from(treestruct_bits) {
            return Err(IOError::new(
                ErrorKind::InvalidData,
                format!(
                    "Incorrect or corrupted data! (Mismatched tree data bits: expected {} got {})",
                    treestruct_bits, tree_bits_read
                ),
            ));
        }

        let mut bitswriter = BitsWrite::create_with(o_byte_writer);
        let mut cursor = root;
        let mut passed = 0u32;
        let mut payload_bits_read = 0u64;
        let mut bbuf = [0u8];
        let mut byte = 0u8;
        /* note: initialized to 0u8 to bootstrap the first byte read */
        let mut check_mask = 0u8;

        while payload_bits_read < payload_bits {
            if check_mask == 0u8 {
                if i_byte_reader.read(&mut bbuf)? < 1 {
                    break;
                }
                byte = bbuf[0];
                check_mask = 0x80u8;
            }
            // `cursor` is reset to the root after each leaf, so it can only be a leaf here
            // when the whole tree is a single leaf, which never carries payload bits.
            let (lc, rc) = match cursor {
                TrivialNodeFullBinTree::Node { lc, rc } => (lc, rc),
                TrivialNodeFullBinTree::Leave(_) => {
                    return Err(IOError::new(
                        ErrorKind::InvalidData,
                        "Incorrect or corrupted data! (Payload present for a single-leaf tree)",
                    ))
                }
            };
            cursor = if byte & check_mask != 0 { rc } else { lc };
            if let TrivialNodeFullBinTree::Leave(data) = cursor {
                if symbol_bits == 8 {
                    bbuf[0] = *data;
                    bitswriter.write_bytes(&bbuf)?;
                } else {
                    bitswriter.write_bits_32_8_unchecked(*data as u32, symbol_bits)?;
                }
                passed += 1;
                cursor = root;
            }
            check_mask >>= 1;
            payload_bits_read += 1;
        }
        // Emit any sub-byte symbols still buffered in the bit writer.
        bitswriter.flush_buffered_bits()?;

        // integrity check #2: all payload bits consumed and decoding ended on a symbol boundary
        if payload_bits_read < payload_bits || !std::ptr::eq(root, cursor) {
            return Err(IOError::new(
                ErrorKind::InvalidData,
                "Incorrect or corrupted data!",
            ));
        }

        Ok(passed)
    }
}
